using CairoMakie,Distributions,KernelDensity,Printf,Random,Statistics

# Number of standard-normal pairs drawn for `mu0` (rows of the simulation matrix).
const M0 = 10_000_000
# Number of draws per (n, Y, prior) combination retained in `draws` for the density plots.
const M1 = 100_000
# Values of n; the noise added to the interval endpoints is scaled by 1/sqrt(n),
# so `Inf` gives the noise-free case.
const n_vec = [100, 500, Inf]
# Values of Y; the interval endpoints are centred at -Y and +Y.
const Y_vec = [2.5, 0.05]

"""
    record_posterior!(results, draws, theta, row, col, j, i, p)

Summarise the draws `theta` obtained under prior `p` for the `j`-th value of `n` and the
`i`-th value of `Y`.

Writes the mean of `theta` to `results[row, col]` and the order statistics at positions
`floor(0.05 * length(theta))` and `floor(0.95 * length(theta))` (the bounds reported as the
90% CI) to `results[row, col + 1]` and `results[row, col + 2]`. The first `size(draws, 1)`
draws are copied to `draws[:, j, i, p]` before sorting. `theta` is sorted in place.
"""
function record_posterior!(results, draws, theta, row, col, j, i, p)
    results[row, col] = mean(theta)
    # Keep the leading draws *before* sorting so the retained sample is in draw order.
    draws[:, j, i, p] = @view theta[1:size(draws, 1)]
    sort!(theta)
    m = length(theta)
    results[row, col + 1] = theta[floor(Int, 0.05 * m)]
    results[row, col + 2] = theta[floor(Int, 0.95 * m)]
    return nothing
end

# results: 5 rows per value of Y (row 1: mean interval endpoints and mean width,
# rows 2-5: priors 1-4) and 3 columns per value of n.
results = Array{Float64,2}(undef, 5 * length(Y_vec), 3 * length(n_vec))
rng = Xoshiro(2024)
# Common standard-normal noise, reused for every (Y, n) combination.
mu0 = randn(rng, M0, 2)
# draws[:, j, i, p]: retained draws for n_vec[j], Y_vec[i] and prior p (slots 1-4 are filled).
draws = Array{Float64,4}(undef, M1, length(n_vec), length(Y_vec), 5)
for (i, Y) in enumerate(Y_vec)
    for (j, n) in enumerate(n_vec)
        # Interval endpoints: (-Y, Y) plus noise scaled by 1/sqrt(n).
        mu = [-Y Y] .+ mu0 ./ sqrt(n)
        # Keep only the draws whose lower endpoint lies below the upper endpoint.
        keep = mu[:, 1] .< mu[:, 2]
        lo = mu[keep, 1]
        hi = mu[keep, 2]
        M = length(lo)

        row = 5 * (i - 1) + 1
        col = 3 * (j - 1) + 1
        results[row, col] = mean(lo)
        results[row, col + 1] = mean(hi)
        results[row, col + 2] = mean(hi .- lo)

        # prior 1: uniform on [lo, hi]
        theta = lo .+ rand(rng, M) .* (hi .- lo)
        record_posterior!(results, draws, theta, row + 1, col, j, i, 1)

        # prior 2: standard normal truncated to [lo, hi]
        # (`truncated` replaces the deprecated `Truncated(d, l, u)` constructor)
        theta = rand.(rng, truncated.(Normal(), lo, hi))
        record_posterior!(results, draws, theta, row + 2, col, j, i, 2)

        # priors 3 and 4: extremely 'informative' two-point priors that put
        # probability `prob` on the lower endpoint and 1 - prob on the upper endpoint
        for (p, prob) in ((3, 0.99), (4, 0.01))
            d = Float64.(rand(rng, Bernoulli(prob), M))
            theta = d .* lo .+ (1 .- d) .* hi
            record_posterior!(results, draws, theta, row + p, col, j, i, p)
        end
    end
end

# Write one table row to `io`: `label`, then every entry of `values` formatted as
# "& %5.2f ", then the row terminator `rowend`.
function write_table_row(io, label, values, rowend)
    print(io, label)
    for v in values
        @printf(io, "& %5.2f ", v)
    end
    print(io, rowend)
    return nothing
end

# Write one `tabular` environment to `io` from five consecutive rows of `results`,
# starting at `firstrow`: the first row holds the interval endpoints and mean width,
# the next four hold the mean and 90% CI bounds under priors 1-4.
# The column headers are hard-coded for n = 100, 500 and infinity.
function write_posterior_table(io, results, firstrow)
    print(io, "\\begin{tabular}{l|lll|lll|lll}\\hline\n")
    print(io, " & \\multicolumn{3}{|c|}{\$n=100\$} & \\multicolumn{3}{|c|}{\$n=500\$} &  \\multicolumn{3}{|c}{\$n=\\infty\$}  \\\\ \\cline{2-10}\n")
    # Header typo "ideantified" in the first column group corrected to "identified".
    print(io, " & \\multicolumn{2}{c}{identified} & mean & \\multicolumn{2}{c}{identified} & mean & \\multicolumn{2}{c}{identified} & mean \\\\ \n")
    print(io, " & \\multicolumn{2}{c}{interval} & width & \\multicolumn{2}{c}{interval} & width & \\multicolumn{2}{c}{interval} & width\\\\ \\hline \n")
    write_table_row(io, "", @view(results[firstrow, :]), "\\\\ \\hline \n")
    print(io, " & mean & \\multicolumn{2}{c|}{90\\% CI} & mean & \\multicolumn{2}{c}{90\\% CI} & mean & \\multicolumn{2}{c}{90\\% CI} \\\\ \\hline \n")
    for p in 1:4
        # Only the last prior row is followed by a horizontal rule.
        rowend = p == 4 ? "\\\\ \\hline \n" : "\\\\ \n"
        write_table_row(io, "prior $p ", @view(results[firstrow + p, :]), rowend)
    end
    print(io, "\\end{tabular}\n\n")
    return nothing
end

# Write both result tables to a standalone LaTeX document. The `do` form closes the
# file even if one of the writes throws.
open("watson.tex", "w") do io
    print(io, "\\documentclass[12pt]{article}\n")
    print(io, "\\usepackage{rotating}\n")
    print(io, "\\begin{document}\n")
    print(io, "\\begin{sidewaystable}\n")
    print(io, "Posterior distribution when \$Y_{1}=-2.5\$ and \$Y_{2}=2.5\$\n")
    print(io, "\\vspace*{1em}\n\n")
    write_posterior_table(io, results, 1)

    print(io, "\\vspace*{2em}\n")
    print(io, "Posterior distribution \$Y_{1}=-0.05\$ and \$Y_{2}=0.05\$\n")
    print(io, "\\vspace*{1em}\n\n")
    write_posterior_table(io, results, 6)
    print(io, "\\end{sidewaystable}\n")
    print(io, "\\end{document}\n")
end

# Figure for n_vec[1] (file names use "T100"): kernel density estimates of the retained
# draws under the four priors, one panel per value of Y, plus a text dump of each curve.
f1 = Figure()

# Upper panel: Y_vec[1], i.e. Y1 = -2.5 and Y2 = 2.5.
k1 = kde(draws[:, 1, 1, 1])
k2 = kde(draws[:, 1, 1, 2])
k3 = kde(draws[:, 1, 1, 3])
k4 = kde(draws[:, 1, 1, 4])
ax1 = Axis(f1[1, 1], limits=(-3.0, 3.0, 0.0, nothing), title=L"$Y_{1}=-2.5$ and $Y_{2}=2.5$")
for (k, cidx, lab) in ((k1, 1, "uniform"), (k2, 3, "truncated normal"),
                       (k3, 2, "informative 1"), (k4, 6, "informative 2"))
    lines!(ax1, k.x, k.density, color=Makie.wong_colors()[cidx], label=lab)
end
vlines!(ax1, [-2.5, 2.5], color=:red, linestyle=:dot)

# Each curve is written from its own grid (the original bounded the k3/k4 loops by
# length(k2.x)); the `do` form closes the file even if a write throws.
for (prefix, k) in (("uniform", k1), ("truncatednormal", k2),
                    ("informative1", k3), ("informative2", k4))
    open(prefix * "T100Y25.txt", "w") do io
        for (x, dens) in zip(k.x, k.density)
            @printf(io, "%6.3f %6.3f\n", x, dens)
        end
    end
end

# Lower panel: Y_vec[2], i.e. Y1 = -0.05 and Y2 = 0.05.
k1 = kde(draws[:, 1, 2, 1])
k2 = kde(draws[:, 1, 2, 2])
k3 = kde(draws[:, 1, 2, 3])
k4 = kde(draws[:, 1, 2, 4])
# Title corrected: this panel shows the Y = 0.05 draws, not the Y = 2.5 ones.
ax2 = Axis(f1[2, 1], limits=(-3.0, 3.0, 0.0, nothing), title=L"$Y_{1}=-0.05$ and $Y_{2}=0.05$")
for (k, cidx, lab) in ((k1, 1, "uniform"), (k2, 3, "truncated normal"),
                       (k3, 2, "informative 1"), (k4, 6, "informative 2"))
    lines!(ax2, k.x, k.density, color=Makie.wong_colors()[cidx], label=lab)
end
vlines!(ax2, [-0.05, 0.05], color=:red, linestyle=:dot)
axislegend(ax2)
save("figure_watson_T100.pdf", f1)

for (prefix, k) in (("uniform", k1), ("truncatednormal", k2),
                    ("informative1", k3), ("informative2", k4))
    open(prefix * "T100Y005.txt", "w") do io
        for (x, dens) in zip(k.x, k.density)
            @printf(io, "%6.3f %6.3f\n", x, dens)
        end
    end
end

# Figure for n_vec[2] (file names use "T500"): kernel density estimates of the retained
# draws under the four priors, one panel per value of Y, plus a text dump of each curve.
f2 = Figure()

# Upper panel: Y_vec[1], i.e. Y1 = -2.5 and Y2 = 2.5.
k1 = kde(draws[:, 2, 1, 1])
k2 = kde(draws[:, 2, 1, 2])
k3 = kde(draws[:, 2, 1, 3])
k4 = kde(draws[:, 2, 1, 4])
ax1 = Axis(f2[1, 1], limits=(-3.0, 3.0, 0.0, nothing), title=L"$Y_{1}=-2.5$ and $Y_{2}=2.5$")
for (k, cidx, lab) in ((k1, 1, "uniform"), (k2, 3, "truncated normal"),
                       (k3, 2, "informative 1"), (k4, 6, "informative 2"))
    lines!(ax1, k.x, k.density, color=Makie.wong_colors()[cidx], label=lab)
end
vlines!(ax1, [-2.5, 2.5], color=:red, linestyle=:dot)

# Each curve is written from its own grid (the original bounded the k3/k4 loops by
# length(k2.x)); the `do` form closes the file even if a write throws.
for (prefix, k) in (("uniform", k1), ("truncatednormal", k2),
                    ("informative1", k3), ("informative2", k4))
    open(prefix * "T500Y25.txt", "w") do io
        for (x, dens) in zip(k.x, k.density)
            @printf(io, "%6.3f %6.3f\n", x, dens)
        end
    end
end

# Lower panel: Y_vec[2], i.e. Y1 = -0.05 and Y2 = 0.05.
k1 = kde(draws[:, 2, 2, 1])
k2 = kde(draws[:, 2, 2, 2])
k3 = kde(draws[:, 2, 2, 3])
k4 = kde(draws[:, 2, 2, 4])
ax2 = Axis(f2[2, 1], limits=(-3.0, 3.0, 0.0, nothing), title=L"$Y_{1}=-0.05$ and $Y_{2}=0.05$")
for (k, cidx, lab) in ((k1, 1, "uniform"), (k2, 3, "truncated normal"),
                       (k3, 2, "informative 1"), (k4, 6, "informative 2"))
    lines!(ax2, k.x, k.density, color=Makie.wong_colors()[cidx], label=lab)
end
vlines!(ax2, [-0.05, 0.05], color=:red, linestyle=:dot)
axislegend(ax2)

save("figure_watson_T500.pdf", f2)

for (prefix, k) in (("uniform", k1), ("truncatednormal", k2),
                    ("informative1", k3), ("informative2", k4))
    open(prefix * "T500Y005.txt", "w") do io
        for (x, dens) in zip(k.x, k.density)
            @printf(io, "%6.3f %6.3f\n", x, dens)
        end
    end
end

# Closed-form densities for the "Tinfty" files, evaluated on the grid -3.0:0.05:3.0 for
# each value of Y: the uniform density on [-Y, Y] and the standard normal density
# truncated to [-Y, Y]; both are written as 0 outside the interval.
# The `do` form closes each file even if a write throws.
for (tag, Y) in (("Y25", Y_vec[1]), ("Y005", Y_vec[2]))
    open("uniformTinfty" * tag * ".txt", "w") do io
        for x in -3.0:0.05:3.0
            @printf(io, "%6.3f %6.3f\n", x, abs(x) <= Y ? 1.0 / (2 * Y) : 0.0)
        end
    end

    # Probability a standard normal assigns to [-Y, Y]; it does not depend on x,
    # so it is computed once per file rather than at every grid point.
    mass = cdf(Normal(), Y) - cdf(Normal(), -Y)
    open("truncatednormalTinfty" * tag * ".txt", "w") do io
        for x in -3.0:0.05:3.0
            @printf(io, "%6.3f %6.3f\n", x, abs(x) <= Y ? pdf(Normal(), x) / mass : 0.0)
        end
    end
end

